{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# 🎶 Music Generation - MuseGAN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own MuseGAN model to generate music in the style of the Bach chorales"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.style.use(\"seaborn-v0_8-colorblind\")\n",
    "\n",
    "import os\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import (\n",
    "    layers,\n",
    "    models,\n",
    "    optimizers,\n",
    "    callbacks,\n",
    "    initializers,\n",
    "    metrics,\n",
    ")\n",
    "\n",
    "from musegan_utils import notes_to_midi, draw_score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "\n",
    "N_BARS = 2\n",
    "N_STEPS_PER_BAR = 16\n",
    "MAX_PITCH = 83\n",
    "N_PITCHES = MAX_PITCH + 1\n",
    "Z_DIM = 32\n",
    "\n",
    "\n",
    "CRITIC_STEPS = 5\n",
    "GP_WEIGHT = 10\n",
    "CRITIC_LEARNING_RATE = 0.001\n",
    "GENERATOR_LEARNING_RATE = 0.001\n",
    "ADAM_BETA_1 = 0.5\n",
    "ADAM_BETA_2 = 0.9\n",
    "EPOCHS = 6000\n",
    "LOAD_MODEL = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d4f5e63-e36a-4dc8-9f03-cb29c1fa5290",
   "metadata": {},
   "source": [
    "## 1. Prepare the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10303bc9-1d3b-4fbe-a70a-eb1627ebdec1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data\n",
    "file = os.path.join(\"/app/data/bach-chorales/Jsb16thSeparated.npz\")\n",
    "with np.load(file, encoding=\"bytes\", allow_pickle=True) as f:\n",
    "    data = f[\"train\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f81c0c2-58ab-4528-a813-dee601a3b020",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_SONGS = len(data)\n",
    "print(f\"{N_SONGS} chorales in the dataset\")\n",
    "chorale = data[0]\n",
    "N_BEATS, N_TRACKS = chorale.shape\n",
    "print(f\"{N_BEATS, N_TRACKS} shape of chorale 0\")\n",
    "print(\"\\nChorale 0\")\n",
    "print(chorale[:8])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faef759f-8e31-4ab8-a8e9-1250bd7913a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "two_bars = np.array([x[: (N_STEPS_PER_BAR * N_BARS)] for x in data])\n",
    "two_bars = np.array(np.nan_to_num(two_bars, nan=MAX_PITCH), dtype=int)\n",
    "two_bars = two_bars.reshape([N_SONGS, N_BARS, N_STEPS_PER_BAR, N_TRACKS])\n",
    "print(f\"Two bars shape {two_bars.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c427e38-dc20-411a-9b34-88f5cf4273f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_binary = np.eye(N_PITCHES)[two_bars]\n",
    "data_binary[data_binary == 0] = -1\n",
    "data_binary = data_binary.transpose([0, 1, 2, 4, 3])\n",
    "print(f\"Data binary shape {data_binary.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8aac7386-3226-4ef9-887d-78d599e0b7dc",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 2. Build the GAN <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fb77d8e-92e1-410d-87b1-f8142807f58c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Some helper functions\n",
    "\n",
    "initializer = initializers.RandomNormal(mean=0.0, stddev=0.02)\n",
    "\n",
    "\n",
    "def conv(x, f, k, s, p):\n",
    "    x = layers.Conv3D(\n",
    "        filters=f,\n",
    "        kernel_size=k,\n",
    "        padding=p,\n",
    "        strides=s,\n",
    "        kernel_initializer=initializer,\n",
    "    )(x)\n",
    "    x = layers.LeakyReLU()(x)\n",
    "    return x\n",
    "\n",
    "\n",
    "def conv_t(x, f, k, s, a, p, bn):\n",
    "    x = layers.Conv2DTranspose(\n",
    "        filters=f,\n",
    "        kernel_size=k,\n",
    "        padding=p,\n",
    "        strides=s,\n",
    "        kernel_initializer=initializer,\n",
    "    )(x)\n",
    "    if bn:\n",
    "        x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "\n",
    "    x = layers.Activation(a)(x)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ae10770-921c-4c61-8005-4b5774ddee51",
   "metadata": {},
   "outputs": [],
   "source": [
    "def TemporalNetwork():\n",
    "    input_layer = layers.Input(shape=(Z_DIM,), name=\"temporal_input\")\n",
    "    x = layers.Reshape([1, 1, Z_DIM])(input_layer)\n",
    "    x = conv_t(x, f=1024, k=(2, 1), s=(1, 1), a=\"relu\", p=\"valid\", bn=True)\n",
    "    x = conv_t(\n",
    "        x, f=Z_DIM, k=(N_BARS - 1, 1), s=(1, 1), a=\"relu\", p=\"valid\", bn=True\n",
    "    )\n",
    "    output_layer = layers.Reshape([N_BARS, Z_DIM])(x)\n",
    "    return models.Model(input_layer, output_layer)\n",
    "\n",
    "\n",
    "TemporalNetwork().summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97b0386b-96ed-4589-bdb3-d9220b745636",
   "metadata": {},
   "outputs": [],
   "source": [
    "def BarGenerator():\n",
    "    input_layer = layers.Input(shape=(Z_DIM * 4,), name=\"bar_generator_input\")\n",
    "\n",
    "    x = layers.Dense(1024)(input_layer)\n",
    "    x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "    x = layers.Activation(\"relu\")(x)\n",
    "    x = layers.Reshape([2, 1, 512])(x)\n",
    "\n",
    "    x = conv_t(x, f=512, k=(2, 1), s=(2, 1), a=\"relu\", p=\"same\", bn=True)\n",
    "    x = conv_t(x, f=256, k=(2, 1), s=(2, 1), a=\"relu\", p=\"same\", bn=True)\n",
    "    x = conv_t(x, f=256, k=(2, 1), s=(2, 1), a=\"relu\", p=\"same\", bn=True)\n",
    "    x = conv_t(x, f=256, k=(1, 7), s=(1, 7), a=\"relu\", p=\"same\", bn=True)\n",
    "    x = conv_t(x, f=1, k=(1, 12), s=(1, 12), a=\"tanh\", p=\"same\", bn=False)\n",
    "\n",
    "    output_layer = layers.Reshape([1, N_STEPS_PER_BAR, N_PITCHES, 1])(x)\n",
    "\n",
    "    return models.Model(input_layer, output_layer)\n",
    "\n",
    "\n",
    "BarGenerator().summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4779bc1-8cfc-43fb-9fb3-1c521c3c9b9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Generator():\n",
    "    chords_input = layers.Input(shape=(Z_DIM,), name=\"chords_input\")\n",
    "    style_input = layers.Input(shape=(Z_DIM,), name=\"style_input\")\n",
    "    melody_input = layers.Input(shape=(N_TRACKS, Z_DIM), name=\"melody_input\")\n",
    "    groove_input = layers.Input(shape=(N_TRACKS, Z_DIM), name=\"groove_input\")\n",
    "\n",
    "    # CHORDS -> TEMPORAL NETWORK\n",
    "    chords_tempNetwork = TemporalNetwork()\n",
    "    chords_over_time = chords_tempNetwork(chords_input)  # [n_bars, z_dim]\n",
    "\n",
    "    # MELODY -> TEMPORAL NETWORK\n",
    "    melody_over_time = [\n",
    "        None\n",
    "    ] * N_TRACKS  # list of n_tracks [n_bars, z_dim] tensors\n",
    "    melody_tempNetwork = [None] * N_TRACKS\n",
    "    for track in range(N_TRACKS):\n",
    "        melody_tempNetwork[track] = TemporalNetwork()\n",
    "        melody_track = layers.Lambda(lambda x, track=track: x[:, track, :])(\n",
    "            melody_input\n",
    "        )\n",
    "        melody_over_time[track] = melody_tempNetwork[track](melody_track)\n",
    "\n",
    "    # CREATE BAR GENERATOR FOR EACH TRACK\n",
    "    barGen = [None] * N_TRACKS\n",
    "    for track in range(N_TRACKS):\n",
    "        barGen[track] = BarGenerator()\n",
    "\n",
    "    # CREATE OUTPUT FOR EVERY TRACK AND BAR\n",
    "    bars_output = [None] * N_BARS\n",
    "    c = [None] * N_BARS\n",
    "    for bar in range(N_BARS):\n",
    "        track_output = [None] * N_TRACKS\n",
    "\n",
    "        c[bar] = layers.Lambda(lambda x, bar=bar: x[:, bar, :])(\n",
    "            chords_over_time\n",
    "        )  # [z_dim]\n",
    "        s = style_input  # [z_dim]\n",
    "\n",
    "        for track in range(N_TRACKS):\n",
    "            m = layers.Lambda(lambda x, bar=bar: x[:, bar, :])(\n",
    "                melody_over_time[track]\n",
    "            )  # [z_dim]\n",
    "            g = layers.Lambda(lambda x, track=track: x[:, track, :])(\n",
    "                groove_input\n",
    "            )  # [z_dim]\n",
    "\n",
    "            z_input = layers.Concatenate(\n",
    "                axis=1, name=\"total_input_bar_{}_track_{}\".format(bar, track)\n",
    "            )([c[bar], s, m, g])\n",
    "\n",
    "            track_output[track] = barGen[track](z_input)\n",
    "\n",
    "        bars_output[bar] = layers.Concatenate(axis=-1)(track_output)\n",
    "\n",
    "    generator_output = layers.Concatenate(axis=1, name=\"concat_bars\")(\n",
    "        bars_output\n",
    "    )\n",
    "\n",
    "    return models.Model(\n",
    "        [chords_input, style_input, melody_input, groove_input],\n",
    "        generator_output,\n",
    "    )\n",
    "\n",
    "\n",
    "generator = Generator()\n",
    "# generator.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19b5b32b-7be3-440f-8f4d-6acb91587468",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Critic():\n",
    "    critic_input = layers.Input(\n",
    "        shape=(N_BARS, N_STEPS_PER_BAR, N_PITCHES, N_TRACKS),\n",
    "        name=\"critic_input\",\n",
    "    )\n",
    "\n",
    "    x = critic_input\n",
    "\n",
    "    x = conv(x, f=128, k=(2, 1, 1), s=(1, 1, 1), p=\"valid\")\n",
    "    x = conv(x, f=128, k=(N_BARS - 1, 1, 1), s=(1, 1, 1), p=\"valid\")\n",
    "    x = conv(x, f=128, k=(1, 1, 12), s=(1, 1, 12), p=\"same\")\n",
    "    x = conv(x, f=128, k=(1, 1, 7), s=(1, 1, 7), p=\"same\")\n",
    "    x = conv(x, f=128, k=(1, 2, 1), s=(1, 2, 1), p=\"same\")\n",
    "    x = conv(x, f=128, k=(1, 2, 1), s=(1, 2, 1), p=\"same\")\n",
    "    x = conv(x, f=256, k=(1, 4, 1), s=(1, 2, 1), p=\"same\")\n",
    "    x = conv(x, f=512, k=(1, 3, 1), s=(1, 2, 1), p=\"same\")\n",
    "\n",
    "    x = layers.Flatten()(x)\n",
    "\n",
    "    x = layers.Dense(1024, kernel_initializer=initializer)(x)\n",
    "    x = layers.LeakyReLU()(x)\n",
    "\n",
    "    critic_output = layers.Dense(\n",
    "        1, activation=None, kernel_initializer=initializer\n",
    "    )(x)\n",
    "\n",
    "    return models.Model(critic_input, critic_output)\n",
    "\n",
    "\n",
    "critic = Critic()\n",
    "# critic.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c75e79ac-d8cb-4000-bc9e-87ef25be4422",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MuseGAN(models.Model):\n",
    "    def __init__(self, critic, generator, latent_dim, critic_steps, gp_weight):\n",
    "        super(MuseGAN, self).__init__()\n",
    "        self.critic = critic\n",
    "        self.generator = generator\n",
    "        self.latent_dim = latent_dim\n",
    "        self.critic_steps = critic_steps\n",
    "        self.gp_weight = gp_weight\n",
    "\n",
    "    def compile(self, c_optimizer, g_optimizer):\n",
    "        super(MuseGAN, self).compile()\n",
    "        self.c_optimizer = c_optimizer\n",
    "        self.g_optimizer = g_optimizer\n",
    "        self.c_wass_loss_metric = metrics.Mean(name=\"c_wass_loss\")\n",
    "        self.c_gp_metric = metrics.Mean(name=\"c_gp\")\n",
    "        self.c_loss_metric = metrics.Mean(name=\"c_loss\")\n",
    "        self.g_loss_metric = metrics.Mean(name=\"g_loss\")\n",
    "\n",
    "    @property\n",
    "    def metrics(self):\n",
    "        return [\n",
    "            self.c_loss_metric,\n",
    "            self.c_wass_loss_metric,\n",
    "            self.c_gp_metric,\n",
    "            self.g_loss_metric,\n",
    "        ]\n",
    "\n",
    "    def gradient_penalty(self, batch_size, real_images, fake_images):\n",
    "        alpha = tf.random.normal([batch_size, 1, 1, 1, 1], 0.0, 1.0)\n",
    "        diff = fake_images - real_images\n",
    "        interpolated = real_images + alpha * diff\n",
    "\n",
    "        with tf.GradientTape() as gp_tape:\n",
    "            gp_tape.watch(interpolated)\n",
    "            pred = self.critic(interpolated, training=True)\n",
    "\n",
    "        grads = gp_tape.gradient(pred, [interpolated])[0]\n",
    "        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))\n",
    "        gp = tf.reduce_mean((norm - 1.0) ** 2)\n",
    "        return gp\n",
    "\n",
    "    def train_step(self, real_images):\n",
    "        batch_size = tf.shape(real_images)[0]\n",
    "\n",
    "        for i in range(self.critic_steps):\n",
    "            chords_random_latent_vectors = tf.random.normal(\n",
    "                shape=(batch_size, self.latent_dim)\n",
    "            )\n",
    "            style_random_latent_vectors = tf.random.normal(\n",
    "                shape=(batch_size, self.latent_dim)\n",
    "            )\n",
    "            melody_random_latent_vectors = tf.random.normal(\n",
    "                shape=(batch_size, N_TRACKS, self.latent_dim)\n",
    "            )\n",
    "            groove_random_latent_vectors = tf.random.normal(\n",
    "                shape=(batch_size, N_TRACKS, self.latent_dim)\n",
    "            )\n",
    "\n",
    "            random_latent_vectors = [\n",
    "                chords_random_latent_vectors,\n",
    "                style_random_latent_vectors,\n",
    "                melody_random_latent_vectors,\n",
    "                groove_random_latent_vectors,\n",
    "            ]\n",
    "\n",
    "            with tf.GradientTape() as tape:\n",
    "                fake_images = self.generator(\n",
    "                    random_latent_vectors, training=True\n",
    "                )\n",
    "                fake_predictions = self.critic(fake_images, training=True)\n",
    "                real_predictions = self.critic(real_images, training=True)\n",
    "\n",
    "                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(\n",
    "                    real_predictions\n",
    "                )\n",
    "                c_gp = self.gradient_penalty(\n",
    "                    batch_size, real_images, fake_images\n",
    "                )\n",
    "                c_loss = c_wass_loss + c_gp * self.gp_weight\n",
    "\n",
    "            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)\n",
    "            self.c_optimizer.apply_gradients(\n",
    "                zip(c_gradient, self.critic.trainable_variables)\n",
    "            )\n",
    "\n",
    "        chords_random_latent_vectors = tf.random.normal(\n",
    "            shape=(batch_size, self.latent_dim)\n",
    "        )\n",
    "        style_random_latent_vectors = tf.random.normal(\n",
    "            shape=(batch_size, self.latent_dim)\n",
    "        )\n",
    "        melody_random_latent_vectors = tf.random.normal(\n",
    "            shape=(batch_size, N_TRACKS, self.latent_dim)\n",
    "        )\n",
    "        groove_random_latent_vectors = tf.random.normal(\n",
    "            shape=(batch_size, N_TRACKS, self.latent_dim)\n",
    "        )\n",
    "\n",
    "        random_latent_vectors = [\n",
    "            chords_random_latent_vectors,\n",
    "            style_random_latent_vectors,\n",
    "            melody_random_latent_vectors,\n",
    "            groove_random_latent_vectors,\n",
    "        ]\n",
    "\n",
    "        with tf.GradientTape() as tape:\n",
    "            fake_images = self.generator(random_latent_vectors, training=True)\n",
    "            fake_predictions = self.critic(fake_images, training=True)\n",
    "            g_loss = -tf.reduce_mean(fake_predictions)\n",
    "\n",
    "        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)\n",
    "        self.g_optimizer.apply_gradients(\n",
    "            zip(gen_gradient, self.generator.trainable_variables)\n",
    "        )\n",
    "\n",
    "        self.c_loss_metric.update_state(c_loss)\n",
    "        self.c_wass_loss_metric.update_state(c_wass_loss)\n",
    "        self.c_gp_metric.update_state(c_gp)\n",
    "        self.g_loss_metric.update_state(g_loss)\n",
    "\n",
    "        return {m.name: m.result() for m in self.metrics}\n",
    "\n",
    "    def generate_piano_roll(self, num_scores):\n",
    "        chords_random_latent_vectors = tf.random.normal(\n",
    "            shape=(num_scores, Z_DIM)\n",
    "        )\n",
    "        style_random_latent_vectors = tf.random.normal(\n",
    "            shape=(num_scores, Z_DIM)\n",
    "        )\n",
    "        melody_random_latent_vectors = tf.random.normal(\n",
    "            shape=(num_scores, N_TRACKS, Z_DIM)\n",
    "        )\n",
    "        groove_random_latent_vectors = tf.random.normal(\n",
    "            shape=(num_scores, N_TRACKS, Z_DIM)\n",
    "        )\n",
    "        random_latent_vectors = [\n",
    "            chords_random_latent_vectors,\n",
    "            style_random_latent_vectors,\n",
    "            melody_random_latent_vectors,\n",
    "            groove_random_latent_vectors,\n",
    "        ]\n",
    "        generated_music = self.generator(random_latent_vectors)\n",
    "        generated_music = generated_music.numpy()\n",
    "        return generated_music"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff97c9b3-e6a7-4ed5-a09c-6ba15a9ac30e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a MuseGAN\n",
    "musegan = MuseGAN(\n",
    "    critic=critic,\n",
    "    generator=generator,\n",
    "    latent_dim=Z_DIM,\n",
    "    critic_steps=CRITIC_STEPS,\n",
    "    gp_weight=GP_WEIGHT,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b35900ab-af6e-41fb-81d5-03dc89c64022",
   "metadata": {},
   "outputs": [],
   "source": [
    "if LOAD_MODEL:\n",
    "    musegan.load_weights(\"./checkpoint/checkpoint.ckpt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e30cf5a8-b6cc-4e88-9e1c-4efdc28b82c4",
   "metadata": {},
   "source": [
    "## 3. Train the MuseGAN <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fe69ad5-371d-401b-94bd-8f6ddc8c97e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the GAN\n",
    "musegan.compile(\n",
    "    c_optimizer=optimizers.Adam(\n",
    "        learning_rate=CRITIC_LEARNING_RATE,\n",
    "        beta_1=ADAM_BETA_1,\n",
    "        beta_2=ADAM_BETA_2,\n",
    "    ),\n",
    "    g_optimizer=optimizers.Adam(\n",
    "        learning_rate=GENERATOR_LEARNING_RATE,\n",
    "        beta_1=ADAM_BETA_1,\n",
    "        beta_2=ADAM_BETA_2,\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8babab8-5eb1-4332-8e2e-4e937c9ef203",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a model save checkpoint\n",
    "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
    "    filepath=\"./checkpoint/checkpoint.ckpt\",\n",
    "    save_weights_only=True,\n",
    "    save_freq=\"epoch\",\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "\n",
    "class MusicGenerator(callbacks.Callback):\n",
    "    def __init__(self, num_scores):\n",
    "        self.num_scores = num_scores\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        if epoch % 1 == 0:\n",
    "            generated_music = self.model.generate_piano_roll(self.num_scores)\n",
    "            notes_to_midi(\n",
    "                generated_music,\n",
    "                N_BARS,\n",
    "                N_TRACKS,\n",
    "                N_STEPS_PER_BAR,\n",
    "                filename=\"output_\" + str(epoch).zfill(4),\n",
    "            )\n",
    "            # display(generated_images, save_to = \"./output/generated_img_%03d.png\" % (epoch), cmap = None)\n",
    "\n",
    "\n",
    "music_generator_callback = MusicGenerator(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ace4abc9-dec5-49a8-a81c-ddcc5b13b13a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "history = musegan.fit(\n",
    "    data_binary,\n",
    "    epochs=EPOCHS,\n",
    "    callbacks=[\n",
    "        model_checkpoint_callback,\n",
    "        tensorboard_callback,\n",
    "        music_generator_callback,\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8959adf3-63f3-4364-95fa-e0a722d8de0e",
   "metadata": {},
   "source": [
    "# Generate new scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebeebf4e-1fb1-4050-8a2b-479200bf7090",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "num_scores = 1\n",
    "chords_random_latent_vectors = np.random.normal(size=(num_scores, Z_DIM))\n",
    "style_random_latent_vectors = np.random.normal(size=(num_scores, Z_DIM))\n",
    "melody_random_latent_vectors = np.random.normal(\n",
    "    size=(num_scores, N_TRACKS, Z_DIM)\n",
    ")\n",
    "groove_random_latent_vectors = np.random.normal(\n",
    "    size=(num_scores, N_TRACKS, Z_DIM)\n",
    ")\n",
    "random_latent_vectors = [\n",
    "    chords_random_latent_vectors,\n",
    "    style_random_latent_vectors,\n",
    "    melody_random_latent_vectors,\n",
    "    groove_random_latent_vectors,\n",
    "]\n",
    "generated_music = generator(random_latent_vectors)\n",
    "generated_music = generated_music.numpy()\n",
    "\n",
    "draw_score(generated_music, 0)\n",
    "notes_to_midi(\n",
    "    generated_music, N_BARS, N_TRACKS, N_STEPS_PER_BAR, filename=\"output_midi\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d93121b1-1073-4aa0-b798-7561171afe6d",
   "metadata": {},
   "source": [
    "## Changing Chord Noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8a10be3-a50e-431a-af8b-18b609f1534c",
   "metadata": {},
   "outputs": [],
   "source": [
    "chords_random_latent_vectors_2 = np.random.normal(size=(num_scores, Z_DIM))\n",
    "random_latent_vectors_2 = [\n",
    "    chords_random_latent_vectors_2,\n",
    "    style_random_latent_vectors,\n",
    "    melody_random_latent_vectors,\n",
    "    groove_random_latent_vectors,\n",
    "]\n",
    "generated_music_2 = generator(random_latent_vectors_2)\n",
    "generated_music_2 = generated_music_2.numpy()\n",
    "draw_score(generated_music_2, 0)\n",
    "notes_to_midi(\n",
    "    generated_music_2,\n",
    "    N_BARS,\n",
    "    N_TRACKS,\n",
    "    N_STEPS_PER_BAR,\n",
    "    filename=\"output_midi_chords_changed\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e663aa28-71a4-4297-88d0-6e88e625827e",
   "metadata": {},
   "source": [
    "# Changing Style Noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bdf5c24-bde6-44b5-ac7d-1aa8a611a4da",
   "metadata": {},
   "outputs": [],
   "source": [
    "style_random_latent_vectors_2 = np.random.normal(size=(num_scores, Z_DIM))\n",
    "random_latent_vectors_3 = [\n",
    "    chords_random_latent_vectors,\n",
    "    style_random_latent_vectors_2,\n",
    "    melody_random_latent_vectors,\n",
    "    groove_random_latent_vectors,\n",
    "]\n",
    "generated_music_3 = generator(random_latent_vectors_3)\n",
    "generated_music_3 = generated_music_3.numpy()\n",
    "draw_score(generated_music_3, 0)\n",
    "notes_to_midi(\n",
    "    generated_music_3,\n",
    "    N_BARS,\n",
    "    N_TRACKS,\n",
    "    N_STEPS_PER_BAR,\n",
    "    filename=\"output_midi_style_changed\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "375319d2-7e91-4f8a-80a5-cb473c9e5e2c",
   "metadata": {},
   "source": [
    "## Changing Melody Noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0775b325-ac94-41fa-b354-ceb6202635be",
   "metadata": {},
   "outputs": [],
   "source": [
    "melody_random_latent_vectors_2 = np.copy(melody_random_latent_vectors)\n",
    "melody_random_latent_vectors_2[:, 0, :] = np.random.normal(\n",
    "    size=(num_scores, Z_DIM)\n",
    ")\n",
    "\n",
    "random_latent_vectors_4 = [\n",
    "    chords_random_latent_vectors,\n",
    "    style_random_latent_vectors,\n",
    "    melody_random_latent_vectors_2,\n",
    "    groove_random_latent_vectors,\n",
    "]\n",
    "generated_music_4 = generator(random_latent_vectors_4)\n",
    "generated_music_4 = generated_music_4.numpy()\n",
    "draw_score(generated_music_4, 0)\n",
    "notes_to_midi(\n",
    "    generated_music_4,\n",
    "    N_BARS,\n",
    "    N_TRACKS,\n",
    "    N_STEPS_PER_BAR,\n",
    "    filename=\"output_midi_melody_changed\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa3a187e-fd3c-4922-90b5-5c0abd82d7f7",
   "metadata": {},
   "source": [
    "## Changing groove noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11834f53-4188-434c-bfe9-8e3858a68889",
   "metadata": {},
   "outputs": [],
   "source": [
    "groove_random_latent_vectors_2 = np.copy(groove_random_latent_vectors)\n",
    "groove_random_latent_vectors_2[:, -1, :] = np.random.normal(\n",
    "    size=(num_scores, Z_DIM)\n",
    ")\n",
    "\n",
    "random_latent_vectors_5 = [\n",
    "    chords_random_latent_vectors,\n",
    "    style_random_latent_vectors,\n",
    "    melody_random_latent_vectors,\n",
    "    groove_random_latent_vectors_2,\n",
    "]\n",
    "generated_music_5 = generator(random_latent_vectors_5)\n",
    "generated_music_5 = generated_music_5.numpy()\n",
    "draw_score(generated_music_5, 0)\n",
    "notes_to_midi(\n",
    "    generated_music_5,\n",
    "    N_BARS,\n",
    "    N_TRACKS,\n",
    "    N_STEPS_PER_BAR,\n",
    "    filename=\"output_midi_groove_changed\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e68165d-8da7-4c8a-a325-ee9fe90ceda4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
